GAN

GAN

背景介绍

  GAN(Generative Adversarial Networks, 生成式对抗网络):是GAN类型网络的初代版本,据说是Ian Goodfellow在2014年喝了一杯啤酒之后,在梦中产生的想法,我不禁感叹,大佬就是大佬啊,虽然现在这是最简单的生成式对抗网络模型,其效果也被很多模型超越,但是它的思想值得我们学习。

gan

GAN特点

  只采用了全连接层和ReLU激活函数,没有使用卷积层对图像进行处理
  生成器的输出使用tanh,产生[-1, 1]的图像,判别器的输出使用sigmoid,产生真或者假的逻辑值

GAN图像分析

generator
discriminator

TensorFlow2.0实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import numpy as np
import cv2 as cv
from functools import reduce
import tensorflow as tf
import tensorflow.keras as keras


def compose(*funcs):
if funcs:
return reduce(lambda f, g: lambda *a, **kw: g(f(*a, **kw)), funcs)
else:
raise ValueError('Composition of empty sequence not supported.')


def generator(input_shape):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = compose(keras.layers.Dense(256, activation='relu', name='dense_relu1'),
keras.layers.BatchNormalization(momentum=0.8, name='bn1'),
keras.layers.Dense(512, activation='relu', name='dense_relu2'),
keras.layers.BatchNormalization(momentum=0.8, name='bn2'),
keras.layers.Dense(1024, activation='relu', name='dense_relu3'),
keras.layers.BatchNormalization(momentum=0.8, name='bn3'),
keras.layers.Dense(784, activation='tanh', name='dense_tanh'),
keras.layers.Reshape((28, 28, 1)))(x)

model = keras.Model(input_tensor, x, name='GAN-Generator')

return model


def discriminator(input_shape):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = compose(keras.layers.Flatten(name='flatten'),
keras.layers.Dense(512, activation='relu', name='dense_relu1'),
keras.layers.Dense(256, activation='relu', name='dense_relu2'),
keras.layers.Dense(1, activation='sigmoid', name='dense_sigmoid'))(x)

model = keras.Model(input_tensor, x, name='GAN-Discriminator')

return model


def gan(input_shape, model_g, model_d):
input_tensor = keras.layers.Input(input_shape, name='input')
x = input_tensor

x = model_g(x)
model_d.trainable = False
x = model_d(x)

model = keras.Model(input_tensor, x, name='GAN')

return model


def save_picture(image, save_path, picture_num):
image = ((image + 1) * 127.5).astype(np.uint8)
image = np.concatenate([image[i * picture_num:(i + 1) * picture_num] for i in range(picture_num)], axis=2)
image = np.concatenate([image[i] for i in range(picture_num)], axis=0)
cv.imwrite(save_path, image)


if __name__ == '__main__':
(x, _), (_, _) = keras.datasets.mnist.load_data()
batch_size = 256
epochs = 20
tf.random.set_seed(22)
save_path = r'.\gan'
if not os.path.exists(save_path):
os.makedirs(save_path)

x = x[..., np.newaxis].astype(np.float) / 127.5 - 1
x = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)

optimizer = keras.optimizers.Adam(0.0002, 0.5)
loss = keras.losses.BinaryCrossentropy()

real_dacc = keras.metrics.BinaryAccuracy()
fake_dacc = keras.metrics.BinaryAccuracy()
gacc = keras.metrics.BinaryAccuracy()

model_d = discriminator(input_shape=(28, 28, 1))
model_d.compile(optimizer=optimizer, loss='binary_crossentropy')

model_g = generator(input_shape=(100,))

model_g.build(input_shape=(100,))
model_g.summary()
keras.utils.plot_model(model_g, 'GAN-generator.png', show_shapes=True, show_layer_names=True)

model_d.build(input_shape=(28, 28, 1))
model_d.summary()
keras.utils.plot_model(model_d, 'GAN-discriminator.png', show_shapes=True, show_layer_names=True)

model = gan(input_shape=(100,), model_g=model_g, model_d=model_d)
model.compile(optimizer=optimizer, loss='binary_crossentropy')

model.build(input_shape=(100,))
model.summary()
keras.utils.plot_model(model, 'GAN.png', show_shapes=True, show_layer_names=True)

for epoch in range(epochs):
x = x.shuffle(np.random.randint(0, 10000))
x_db = iter(x)

for step, real_image in enumerate(x_db):
noise = np.random.normal(0, 1, (real_image.shape[0], 100))
fake_image = model_g(noise)

real_dacc(np.ones((real_image.shape[0], 1)), model_d(real_image))
fake_dacc(np.zeros((real_image.shape[0], 1)), model_d(fake_image))
gacc(np.ones((real_image.shape[0], 1)), model(noise))

real_dloss = model_d.train_on_batch(real_image, np.ones((real_image.shape[0], 1)))
fake_dloss = model_d.train_on_batch(fake_image, np.zeros((real_image.shape[0], 1)))
gloss = model.train_on_batch(noise, np.ones((real_image.shape[0], 1)))

if step % 20 == 0:
print('epoch = {}, step = {}, real_dacc = {}, fake_dacc = {}, gacc = {}'.format(epoch, step, real_dacc.result(), fake_dacc.result(), gacc.result()))
real_dacc.reset_states()
fake_dacc.reset_states()
gacc.reset_states()
fake_data = np.random.normal(0, 1, (100, 100))
fake_image = model_g(fake_data)
save_picture(fake_image.numpy(), save_path + '\\epoch{}_step{}.jpg'.format(epoch, step), 10)

gan

模型运行结果

gan

小技巧

  1. 图像输入可以先将其归一化到0-1之间或者-1-1之间,因为网络的参数一般都比较小,所以归一化后计算方便,收敛较快。
  2. 注意其中的一些维度变换和numpytensorflow常用操作,否则在阅读代码时可能会产生一些困难。
  3. 可以设置一些权重的保存方式学习率的下降方式早停方式
  4. GAN对于网络结构,优化器参数,网络层的一些超参数都是非常敏感的,效果不好不容易发现原因,这可能需要较多的工程实践经验
  5. 先创建判别器,然后进行compile,这样判别器就固定了,然后创建生成器时,不要训练判别器,需要将判别器的trainable改成False,此时不会影响之前固定的判别器,这个可以通过模型的_collection_collected_trainable_weights属性查看,如果该属性为空,则模型不训练,否则模型可以训练,compile之后,该属性固定,无论后面如何修改trainable,只要不重新compile,都不影响训练。
  6. 因为全都是全连接层,GAN适用于小目标的生成,如果是一个512x512x3的图像的生成,使用全连接层,那么就需要786432个神经元,上一层的神经元数目应该更多,设为1048576个,那么这两层之间全连接的参数量为八千多亿个,这是非常不现实的,而且效果也会特别差。

GAN小结

  GAN是一种非常简单的生成式对抗网络,从上图可以看出GAN模型的参数量只有2M,虽然现在GAN网络不是最好的生成式对抗网络,但是其网络对抗思想,对后面的深度学习网络的发展有重要的影响。

-------------本文结束感谢您的阅读-------------
0%